#  -*- coding: utf-8 -*-
#  Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


"""
msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter 模块的单元测试
"""

import unittest
from collections import defaultdict
from unittest.mock import Mock, patch
import torch
import torch.nn as nn

from ascend_utils.common.utils import CallParams
from ascend_utils.core.dag.dag_node import DagNode
from msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook import DagTorchHook
from msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_infos import ModuleType
from msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process import StructureProcess
from msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter import DagNodeInfo, TorchDAGAdapter


class TestDagNodeInfo(unittest.TestCase):
    """测试DagNodeInfo类"""

    def setUp(self):
        """测试前准备"""
        self.dag_node_info = DagNodeInfo(
            name="test_node",
            class_type="Linear",
            input_nodes=["input1", "input2"],
            output_nodes=["output1", "output2"]
        )

    def test_init(self):
        """测试初始化"""
        self.assertEqual(self.dag_node_info.name, "test_node")
        self.assertEqual(self.dag_node_info.class_type, "Linear")
        self.assertEqual(self.dag_node_info.input_nodes, ["input1", "input2"])
        self.assertEqual(self.dag_node_info.output_nodes, ["output1", "output2"])

    def test_eq_true(self):
        """测试相等比较"""
        other = DagNodeInfo(
            name="test_node",
            class_type="Linear",
            input_nodes=["input1", "input2"],
            output_nodes=["output1", "output2"]
        )
        self.assertEqual(self.dag_node_info, other)

    def test_eq_false_different_name(self):
        """测试名称不同"""
        other = DagNodeInfo(
            name="other_node",
            class_type="Linear",
            input_nodes=["input1", "input2"],
            output_nodes=["output1", "output2"]
        )
        self.assertNotEqual(self.dag_node_info, other)

    def test_eq_false_different_class_type(self):
        """测试类型不同"""
        other = DagNodeInfo(
            name="test_node",
            class_type="Conv2d",
            input_nodes=["input1", "input2"],
            output_nodes=["output1", "output2"]
        )
        self.assertNotEqual(self.dag_node_info, other)

    def test_eq_false_different_input_nodes(self):
        """测试输入节点不同"""
        other = DagNodeInfo(
            name="test_node",
            class_type="Linear",
            input_nodes=["input1", "input3"],
            output_nodes=["output1", "output2"]
        )
        self.assertNotEqual(self.dag_node_info, other)

    def test_eq_false_different_output_nodes(self):
        """测试输出节点不同"""
        other = DagNodeInfo(
            name="test_node",
            class_type="Linear",
            input_nodes=["input1", "input2"],
            output_nodes=["output1", "output3"]
        )
        self.assertNotEqual(self.dag_node_info, other)

    def test_eq_none(self):
        """测试与None比较"""
        self.assertIsNotNone(self.dag_node_info)


class MockDagNode:
    def __init__(self, name, op_type, input_nodes=None, output_nodes=None, inputs=None, outputs=None):
        self.name_in_network = name
        self.op_type = op_type
        self.input_nodes = input_nodes or []
        self.output_nodes = output_nodes or []
        self.inputs = inputs or []
        self.outputs = outputs or []


class TestTorchDAGAdapter(unittest.TestCase):
    """测试TorchDAGAdapter类"""

    def setUp(self):
        """测试前准备"""
        self.model = nn.Linear(224 * 224 * 3, 10)

        self.dag_nodes = [
            MockDagNode("layer1", "Linear"),
            MockDagNode("relu", "ReLU"),
            MockDagNode("layer2", "Linear"),
            MockDagNode("add_op", "__add__"),
            MockDagNode("norm", "LayerNorm")
        ]

        self.dag_nodes[0].output_nodes = [self.dag_nodes[1]]
        self.dag_nodes[1].input_nodes = [self.dag_nodes[0]]
        self.dag_nodes[1].output_nodes = [self.dag_nodes[2]]
        self.dag_nodes[2].input_nodes = [self.dag_nodes[1]]

        self.mock_dag_hook = Mock()
        self.mock_dag_hook.dag_node_list = self.dag_nodes

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook.DagTorchHook',
                   return_value=self.mock_dag_hook):
            self.adapter = TorchDAGAdapter(
                model=self.model,
                dummy_input=torch.randn(1, 224 * 224 * 3),
                hook_nodes=[nn.LayerNorm]
            )

    def test_init_with_default_dummy_input(self):
        """测试使用默认dummy_input初始化"""
        default_model = nn.Linear(224 * 224 * 3, 10)
        mock_dag_hook = Mock()
        mock_dag_hook.dag_node_list = [MockDagNode("linear", "Linear")]

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook.DagTorchHook',
                   return_value=mock_dag_hook):
            adapter = TorchDAGAdapter(default_model, dummy_input=torch.randn(1, 224 * 224 * 3))
            self.assertIsNotNone(adapter._dummy_input)
            self.assertEqual(adapter._dummy_input.shape, torch.Size([1, 224 * 224 * 3]))

    def test_init_with_tensor_dummy_input(self):
        """测试使用tensor dummy_input初始化"""
        dummy_input = torch.randn(2, 10)
        model = nn.Linear(10, 20)
        mock_dag_hook = Mock()
        mock_dag_hook.dag_node_list = [MockDagNode("linear", "Linear")]

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook.DagTorchHook',
                   return_value=mock_dag_hook):
            adapter = TorchDAGAdapter(model, dummy_input=dummy_input)
            self.assertTrue(torch.equal(adapter._dummy_input, dummy_input))

    def test_init_with_hook_nodes(self):
        """测试使用hook_nodes初始化"""
        model = nn.Linear(10, 20)
        mock_dag_hook = Mock()
        mock_dag_hook.dag_node_list = [MockDagNode("linear", "Linear")]

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook.DagTorchHook',
                   return_value=mock_dag_hook):
            adapter = TorchDAGAdapter(model, dummy_input=torch.randn(1, 10), hook_nodes=[nn.LayerNorm, nn.BatchNorm2d])
            self.assertIn("layernorm", adapter.norm_nodes)
            self.assertIn("batchnorm2d", adapter.norm_nodes)

    def test_init_with_none_hook_nodes(self):
        """测试使用None hook_nodes初始化"""
        model = nn.Linear(10, 20)
        mock_dag_hook = Mock()
        mock_dag_hook.dag_node_list = [MockDagNode("linear", "Linear")]

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook.DagTorchHook',
                   return_value=mock_dag_hook):
            adapter = TorchDAGAdapter(model, dummy_input=torch.randn(1, 10), hook_nodes=None)
            self.assertEqual(adapter.norm_nodes, [])

    def test_get_node_name_in_order(self):
        """测试获取节点名称"""
        with patch.object(self.adapter, 'node_list', self.dag_nodes):
            names = self.adapter.get_node_name_in_order()
            expected = [node.name_in_network for node in self.dag_nodes]
            self.assertEqual(names, expected)

    def test_get_name_type_dict_in_order(self):
        """测试获取名称类型字典"""
        with patch.object(self.adapter, 'node_list', self.dag_nodes):
            result = self.adapter.get_name_type_dict_in_order()
            expected = {node.name_in_network: node.op_type for node in self.dag_nodes}
            self.assertEqual(result, expected)

    def test_get_network_topology(self):
        """测试获取网络拓扑"""
        with patch.object(self.adapter, 'node_list', self.dag_nodes):
            topology = self.adapter.get_network_topology()
            self.assertEqual(len(topology), len(self.dag_nodes))

            first_node = topology[0]
            self.assertEqual(first_node.name, self.dag_nodes[0].name_in_network)
            self.assertEqual(first_node.class_type, self.dag_nodes[0].op_type)
            self.assertEqual(first_node.input_nodes,
                             [item.name_in_network if item else None for item in self.dag_nodes[0].input_nodes])
            self.assertEqual(first_node.output_nodes,
                             [item.name_in_network if item else None for item in self.dag_nodes[0].output_nodes])

    def test_get_mhsa_pattern_with_valid_conditions(self):
        """测试获取MHSA模式，满足条件的情况"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_add_node.input_nodes = [mock_node, MockDagNode("other_input", "Linear")]
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]

        test_nodes = [mock_node, mock_add_node, mock_other_branch, MockDagNode("linear1", "Linear"),
                      MockDagNode("linear2", "Linear")]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.mhsa_matmul_process') \
                        as mock_process:
                qkv_list, proj_list = self.adapter.get_mhsa_pattern()
                mock_process.assert_called()
                self.assertIsInstance(qkv_list, list)
                self.assertIsInstance(proj_list, list)

    def test_get_mhsa_pattern_with_continue_condition_len_add_node_not_1(self):
        """测试MHSA模式中满足continue条件：add_node数量不为1"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node1 = MockDagNode("add1", "__add__")
        mock_add_node2 = MockDagNode("add2", "__add__")

        mock_node.output_nodes = [mock_add_node1, mock_add_node2]

        test_nodes = [mock_node, mock_add_node1, mock_add_node2]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.mhsa_matmul_process') \
                        as mock_process:
                qkv_list, proj_list = self.adapter.get_mhsa_pattern()
                mock_process.assert_not_called()
                self.assertIsInstance(qkv_list, list)
                self.assertIsInstance(proj_list, list)

    def test_get_mhsa_pattern_with_continue_condition_len_add_inputs_not_2(self):
        """测试MHSA模式中满足continue条件：add_node输入数量不为2"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock()]
        mock_add_node.input_nodes = [mock_node]

        mock_other_branch = MockDagNode("other", "Linear")
        mock_node.output_nodes = [mock_add_node, mock_other_branch]

        test_nodes = [mock_node, mock_add_node, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.mhsa_matmul_process') \
                        as mock_process:
                qkv_list, proj_list = self.adapter.get_mhsa_pattern()
                mock_process.assert_not_called()
                self.assertIsInstance(qkv_list, list)
                self.assertIsInstance(proj_list, list)

    def test_get_mhsa_pattern_with_continue_condition_no_other_branch(self):
        """测试MHSA模式中满足continue条件：没有other_branch"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_add_node.input_nodes = [mock_node, MockDagNode("other_input", "Linear")]

        mock_node.output_nodes = [mock_add_node]

        test_nodes = [mock_node, mock_add_node, MockDagNode("other_input", "Linear")]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.mhsa_matmul_process') \
                        as mock_process:
                qkv_list, proj_list = self.adapter.get_mhsa_pattern()
                mock_process.assert_not_called()
                self.assertIsInstance(qkv_list, list)
                self.assertIsInstance(proj_list, list)

    def test_get_ffn_pattern_with_valid_conditions(self):
        """测试获取FFN模式，满足条件的情况"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]
        mock_add_node.input_nodes = [mock_node, MockDagNode("other_input", "Linear")]

        test_nodes = [mock_node, mock_add_node, mock_other_branch, MockDagNode("linear1", "Linear"),
                      MockDagNode("linear2", "Linear")]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                       'StructureProcess.is_ffn_matmul',
                       return_value=True):
                ffn_pattern = self.adapter.get_ffn_pattern()
                self.assertIsInstance(ffn_pattern, list)

    def test_get_ffn_pattern_with_continue_condition_len_add_node_not_1(self):
        """测试FFN模式中满足continue条件：add_node数量不为1"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node1 = MockDagNode("add1", "__add__")
        mock_add_node2 = MockDagNode("add2", "__add__")
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node1, mock_add_node2, mock_other_branch]

        test_nodes = [mock_node, mock_add_node1, mock_add_node2, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.is_ffn_matmul') \
                        as mock_is_ffn:
                ffn_pattern = self.adapter.get_ffn_pattern()
                mock_is_ffn.assert_not_called()
                self.assertIsInstance(ffn_pattern, list)

    def test_get_ffn_pattern_with_continue_condition_len_other_branch_not_1(self):
        """测试FFN模式中满足continue条件：other_branch数量不为1"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_other_branch1 = MockDagNode("other1", "Linear")
        mock_other_branch2 = MockDagNode("other2", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch1, mock_other_branch2]

        test_nodes = [mock_node, mock_add_node, mock_other_branch1, mock_other_branch2]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.is_ffn_matmul') \
                        as mock_is_ffn:
                ffn_pattern = self.adapter.get_ffn_pattern()
                mock_is_ffn.assert_not_called()
                self.assertIsInstance(ffn_pattern, list)

    def test_get_ffn_pattern_with_continue_condition_len_add_inputs_not_2(self):
        """测试FFN模式中满足continue条件：add_node输入数量不为2"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock()]
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]

        test_nodes = [mock_node, mock_add_node, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.is_ffn_matmul') \
                        as mock_is_ffn:
                ffn_pattern = self.adapter.get_ffn_pattern()
                mock_is_ffn.assert_not_called()
                self.assertIsInstance(ffn_pattern, list)

    def test_get_mhsa_ln_pattern(self):
        """测试获取MHSA LN模式"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]
        mock_add_node.input_nodes = [mock_node, MockDagNode("other_input", "Linear")]

        test_nodes = [mock_node, mock_add_node, mock_other_branch, MockDagNode("linear1", "Linear"),
                      MockDagNode("linear2", "Linear"), MockDagNode("norm", "LayerNorm")]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                       'StructureProcess.mhsa_matmul_ln_process'):
                qkv_list, proj_list, ln_list = self.adapter.get_mhsa_ln_pattern()
                self.assertIsInstance(qkv_list, list)
                self.assertIsInstance(proj_list, list)
                self.assertIsInstance(ln_list, list)

    def test_get_ffn_ln_pattern(self):
        """测试获取FFN LN模式"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]
        mock_add_node.input_nodes = [mock_node, MockDagNode("other_input", "Linear")]

        test_nodes = [mock_node, mock_add_node, mock_other_branch, MockDagNode("linear1", "Linear"),
                      MockDagNode("linear2", "Linear"), MockDagNode("norm", "LayerNorm")]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                       'TorchDAGAdapter._get_ffn_pattern_and_ln_list'):
                ffn_pattern, ffn_ln_list = self.adapter.get_ffn_ln_pattern()
                self.assertIsInstance(ffn_pattern, list)
                self.assertIsInstance(ffn_ln_list, list)

    def test_get_llama_mhsa_ln_pattern(self):
        """测试获取LLAMA MHSA LN模式"""
        with patch.object(self.adapter, 'get_mhsa_ln_pattern') as mock_get:
            self.adapter.get_llama_mhsa_ln_pattern()
            mock_get.assert_called_once_with(ln_type='Llamarmsnormbias')

    def test_get_llama_ffn_ln_pattern_with_valid_conditions(self):
        """测试获取LLAMA FFN LN模式，满足条件的情况"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]
        mock_add_node.input_nodes = [mock_node, MockDagNode("other_input", "Linear")]

        test_nodes = [mock_node, mock_add_node, mock_other_branch, MockDagNode("linear1", "Linear"),
                      MockDagNode("linear2", "Linear"), MockDagNode("norm", "LayerNorm")]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                       'TorchDAGAdapter._get_ffn_pattern_and_ln_list'):
                ffn_pattern, ffn_ln_list = self.adapter.get_llama_ffn_ln_pattern()
                self.assertIsInstance(ffn_pattern, list)
                self.assertIsInstance(ffn_ln_list, list)

    def test_get_llama_ffn_ln_pattern_with_continue_condition_len_add_node_not_1(self):
        """测试LLAMA FFN LN模式中满足continue条件：add_node数量不为1"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node1 = MockDagNode("add1", "__add__")
        mock_add_node2 = MockDagNode("add2", "__add__")
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node1, mock_add_node2, mock_other_branch]

        test_nodes = [mock_node, mock_add_node1, mock_add_node2, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                    'TorchDAGAdapter._get_ffn_pattern_and_ln_list') \
                        as mock_get:
                ffn_pattern, ffn_ln_list = self.adapter.get_llama_ffn_ln_pattern()
                mock_get.assert_not_called()
                self.assertIsInstance(ffn_pattern, list)
                self.assertIsInstance(ffn_ln_list, list)

    def test_get_llama_ffn_ln_pattern_with_continue_condition_len_other_branch_not_1(self):
        """测试LLAMA FFN LN模式中满足continue条件：other_branch数量不为1"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]
        mock_other_branch1 = MockDagNode("other1", "Linear")
        mock_other_branch2 = MockDagNode("other2", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch1, mock_other_branch2]

        test_nodes = [mock_node, mock_add_node, mock_other_branch1, mock_other_branch2]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                    'TorchDAGAdapter._get_ffn_pattern_and_ln_list') \
                        as mock_get:
                ffn_pattern, ffn_ln_list = self.adapter.get_llama_ffn_ln_pattern()
                mock_get.assert_not_called()
                self.assertIsInstance(ffn_pattern, list)
                self.assertIsInstance(ffn_ln_list, list)

    def test_get_llama_ffn_ln_pattern_with_continue_condition_len_add_inputs_not_2(self):
        """测试LLAMA FFN LN模式中满足continue条件：add_node输入数量不为2"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock()]
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]

        test_nodes = [mock_node, mock_add_node, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                    'TorchDAGAdapter._get_ffn_pattern_and_ln_list') \
                        as mock_get:
                ffn_pattern, ffn_ln_list = self.adapter.get_llama_ffn_ln_pattern()
                mock_get.assert_not_called()
                self.assertIsInstance(ffn_pattern, list)
                self.assertIsInstance(ffn_ln_list, list)

    def test_get_norm_linear_subgraph(self):
        """测试获取norm-linear子图"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        linear3 = MockDagNode("linear3", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, linear3, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                result = self.adapter.get_norm_linear_subgraph()
                self.assertIsInstance(result, defaultdict)
                self.assertIn(norm_node.name_in_network, result)

    def test_get_kv_linears(self):
        """测试获取KV线性层"""
        # 创建包含norm节点的测试数据
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        linear3 = MockDagNode("linear3", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, linear3, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                kv_linears, num_kv = self.adapter.get_kv_linears()
                self.assertIsInstance(kv_linears, list)
                self.assertIsInstance(num_kv, int)

    def test_get_linear_linear_subgraph(self):
        """测试获取linear-linear子图"""
        # 创建包含norm节点的测试数据
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, '_find_module_by_name', return_value=Mock()):
                    result = self.adapter.get_linear_linear_subgraph()
                    self.assertIsInstance(result, defaultdict)

    def test_get_allreduce_linear(self):
        """测试获取allreduce线性层"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                result = self.adapter.get_allreduce_linear()
                self.assertIsInstance(result, list)

    def test_find_module_by_name(self):
        """测试通过名称查找模块"""
        # 测试正常情况
        result = self.adapter._find_module_by_name("model.0")
        # 由于模型结构，这里会返回Linear层或None

        # 测试None输入
        result = self.adapter._find_module_by_name(None)
        self.assertIsNone(result)

    def test_split_by_add_op_type(self):
        """测试按add操作类型分割"""
        node = MockDagNode("test", "Linear")
        add_node = MockDagNode("__add__", "__add__")
        other_node = MockDagNode("other", "ReLU")

        node.output_nodes = [add_node, other_node]

        add_nodes, other_branch = self.adapter._split_by_add_op_type(node)
        self.assertEqual(len(add_nodes), 1)
        self.assertEqual(len(other_branch), 1)
        self.assertEqual(add_nodes[0].op_type, "__add__")
        self.assertEqual(other_branch[0].op_type, "ReLU")

    def test_find_input_nodes(self):
        """测试查找输入节点"""
        # 创建一些节点
        input_node = MockDagNode("input", "Input")
        tensor_node = MockDagNode("_TensorBase.size", "_TensorBase.size")  # 应该被过滤
        normal_node = MockDagNode("normal", "Linear")

        # 模拟dag_node_from属性
        input_node.inputs = [Mock()]
        input_node.inputs[0].dag_node_from = None  # 没有来源，是输入

        tensor_node.inputs = [Mock()]
        tensor_node.inputs[0].dag_node_from = Mock()

        normal_node.inputs = [Mock()]
        normal_node.inputs[0].dag_node_from = None

        with patch.object(self.adapter, 'node_list', [input_node, tensor_node, normal_node]):
            input_nodes = self.adapter._find_input_nodes()
            self.assertIn(input_node, input_nodes)
            self.assertNotIn(tensor_node, input_nodes)  # 被dark_name_list过滤
            self.assertIn(normal_node, input_nodes)

    def test_find_output_nodes(self):
        """测试查找输出节点"""
        output_node = MockDagNode("output", "Output")
        tensor_node = MockDagNode("_TensorBase.size", "_TensorBase.size")  # 应该被过滤
        normal_node = MockDagNode("normal", "Linear")

        # 模拟outputs属性
        output_node.outputs = [Mock()]
        output_node.outputs[0].dag_nodes_to = []  # 没有去向，是输出

        tensor_node.outputs = [Mock()]
        tensor_node.outputs[0].dag_nodes_to = [Mock()]

        normal_node.outputs = [Mock()]
        normal_node.outputs[0].dag_nodes_to = []  # 没有去向，是输出

        with patch.object(self.adapter, 'node_list', [output_node, tensor_node, normal_node]):
            output_nodes = self.adapter._find_output_nodes()
            self.assertIn(output_node, output_nodes)
            self.assertNotIn(tensor_node, output_nodes)  # 被dark_name_list过滤
            self.assertIn(normal_node, output_nodes)

    def test_get_trav_node(self):
        """测试获取遍历节点"""
        add_node = [MockDagNode("add", "__add__")]
        add_node[0].input_nodes = ["node1", "node2"]
        node = "node1"

        trav_node = TorchDAGAdapter._get_trav_node(add_node, node)
        self.assertEqual(trav_node, "node2")

    def test_get_trav_node_single_input(self):
        """测试单个输入节点情况"""
        add_node = [MockDagNode("add", "__add__")]
        add_node[0].input_nodes = ["node1"]
        node = "node1"

        trav_node = TorchDAGAdapter._get_trav_node(add_node, node)
        self.assertIsNone(trav_node)

    def test_get_llm_network_pattern_auto(self):
        """测试自动获取LLM网络模式"""
        norm_node1 = MockDagNode("norm1", "LayerNorm")
        norm_node2 = MockDagNode("norm2", "LayerNorm")
        norm_node3 = MockDagNode("norm3", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")

        test_nodes = [norm_node1, linear1, linear2, norm_node2, linear1, linear2, norm_node3]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, 'add_mhsa_norm_linears'):
                    result = self.adapter.get_llm_network_pattern_auto()
                    self.assertEqual(len(result), 5)  # 返回5个列表

    def test_add_mhsa_norm_linears_short_interval(self):
        """测试添加MHSA norm线性层（短间隔）"""
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        linear3 = MockDagNode("linear3", "Linear")
        linear4 = MockDagNode("linear4", "Linear")

        test_nodes = \
            [MockDagNode("norm", "LayerNorm"), linear1, linear2, linear3, linear4, MockDagNode("end", "Linear")]

        mhsa_linear = []
        mhsa_o = []
        mhsa_ln = []

        with patch.object(self.adapter, 'node_list', test_nodes):
            self.adapter.add_mhsa_norm_linears(mhsa_linear, mhsa_o, mhsa_ln, [0, 5])
            self.assertEqual(len(mhsa_linear), 1)
            self.assertEqual(len(mhsa_o), 1)
            self.assertEqual(len(mhsa_ln), 1)

    def test_add_mhsa_norm_linears_long_interval(self):
        """测试添加MHSA norm线性层（长间隔）"""
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        linear3 = MockDagNode("linear3", "Linear")
        linear4 = MockDagNode("linear4", "Linear")
        linear5 = MockDagNode("linear5", "Linear")
        linear6 = MockDagNode("linear6", "Linear")

        test_nodes = [MockDagNode("norm", "LayerNorm"), linear1, linear2, linear3,
                      linear4, linear5, linear6, MockDagNode("end", "Linear")]

        mhsa_linear = []
        mhsa_o = []
        mhsa_ln = []

        with patch.object(self.adapter, 'node_list', test_nodes):
            self.adapter.add_mhsa_norm_linears(mhsa_linear, mhsa_o, mhsa_ln, [0, 7])
            self.assertEqual(len(mhsa_linear), 1)
            self.assertEqual(len(mhsa_o), 1)
            self.assertEqual(len(mhsa_ln), 1)

    def test_get_all_interval_nodes_by_type(self):
        """测试获取指定类型的所有间隔节点"""
        stop_node = MockDagNode("stop", "Linear")
        trav_node = MockDagNode("trav", "Linear")

        result = self.adapter._get_all_interval_nodes_by_type(stop_node, trav_node)
        self.assertIsInstance(result, list)

    def test_get_all_interval_nodes_by_type_ignore_stop_node(self):
        """测试忽略停止节点的情况"""
        stop_node = MockDagNode("stop", "Linear")
        trav_node = MockDagNode("trav", "Linear")

        result = self.adapter._get_all_interval_nodes_by_type(stop_node, trav_node, ignore_stop_node=True)
        self.assertIsInstance(result, list)

    def test_find_latest_module_by_type_bfs(self):
        """测试BFS查找最新模块类型"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")

        search_node.input_nodes = [target_node]

        result = self.adapter._find_latest_module_by_type_bfs(search_node, "Linear")
        self.assertIsInstance(result, list)

    def test_find_latest_module_by_type_bfs_none_input(self):
        """测试BFS查找模块类型，输入为None"""
        result = self.adapter._find_latest_module_by_type_bfs(None, "Linear")
        self.assertEqual(result, [])

    def test_find_latest_module_by_type_bfs_stop_node(self):
        """测试BFS查找模块类型，遇到停止节点"""
        search_node = MockDagNode("search", "Linear")
        stop_node = MockDagNode("stop", "Stop")

        search_node.input_nodes = [stop_node]

        result = self.adapter._find_latest_module_by_type_bfs(search_node, "Linear", stop_node=stop_node)
        self.assertEqual(result, [])

    def test_find_latest_module_by_type_recursive(self):
        """测试递归查找最新模块类型"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")

        search_node.input_nodes = [target_node]

        with patch.object(self.adapter, '_find_latest_module_by_type') as mock_find:
            mock_find.return_value = [target_node]
            result = self.adapter._find_latest_module_by_type(search_node, "Linear")
            self.assertIsInstance(result, list)

    def test_find_latest_module_by_type_none_input(self):
        """测试递归查找模块类型，输入为None"""
        result = self.adapter._find_latest_module_by_type(None, "Linear")
        self.assertEqual(result, [])

    def test_find_latest_module_by_type_conv2d_or_getitem(self):
        """测试遇到Conv2d或GetItem的情况"""
        search_node = MockDagNode("search", "Linear")
        conv_node = MockDagNode("conv", ModuleType.CONV2D)

        search_node.input_nodes = [conv_node]

        result = self.adapter._find_latest_module_by_type(search_node, "Linear")
        self.assertEqual(result, [])

    def test_find_latest_module_by_type_stop_node_in_output(self):
        """测试递归查找模块类型，停止节点在输出中"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")
        stop_node = MockDagNode("stop", "Stop")

        search_node.input_nodes = [target_node]
        target_node.output_nodes = [stop_node]  # stop_node在target_node的输出中

        result = self.adapter._find_latest_module_by_type(search_node, "Linear", stop_node=stop_node)
        self.assertEqual(result, [])  # 应该返回空列表

    def test_find_latest_module_by_type_stop_node_equal(self):
        """测试递归查找模块类型，当前节点等于停止节点"""
        search_node = MockDagNode("search", "Linear")
        stop_node = MockDagNode("stop", "Stop")

        search_node.input_nodes = [stop_node]

        result = self.adapter._find_latest_module_by_type(search_node, "Linear", stop_node=stop_node)
        self.assertEqual(result, [])  # 应该返回空列表

    def test_find_latest_module_by_type_found_target(self):
        """测试递归查找模块类型，找到目标节点"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")

        search_node.input_nodes = [target_node]

        result = self.adapter._find_latest_module_by_type(search_node, "Linear")
        self.assertEqual(result, [target_node])  # 应该返回包含目标节点的列表

    def test_find_latest_module_by_type_branch_aware_true(self):
        """测试branch_aware=True的情况"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")
        other_node = MockDagNode("other", "Linear")

        search_node.input_nodes = [target_node]

        with patch.object(self.adapter, '_find_latest_module_by_type') as mock_find:
            mock_find.return_value = [target_node]
            result = self.adapter._find_latest_module_by_type(search_node, "Linear", branch_aware=True)
            self.assertIsInstance(result, list)

    def test_get_ffn_pattern_and_ln_list(self):
        """测试获取FFN模式和LN列表"""
        matmul_list = [MockDagNode("linear1", "Linear"), MockDagNode("linear2", "Linear")]
        ln_list = [MockDagNode("norm", "LayerNorm")]
        stop_node = MockDagNode("stop", "Linear")
        ffn_pattern = []
        ffn_ln_list = []

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.'
                   'model_structure_process.StructureProcess.is_ffn_matmul',
                   return_value=True):
            TorchDAGAdapter._get_ffn_pattern_and_ln_list((2, matmul_list), ln_list,
                                                         stop_node, ffn_pattern, ffn_ln_list)
            self.assertEqual(len(ffn_pattern), 1)
            self.assertEqual(len(ffn_ln_list), 1)

    def test_get_ffn_pattern_and_ln_list_reverse_condition(self):
        """测试需要反转列表的情况"""
        matmul_list = [MockDagNode("linear1", "Linear"), MockDagNode("linear2", "Linear")]
        ln_list = [MockDagNode("norm", "LayerNorm")]
        stop_node = MockDagNode("stop", "OtherType")  # 不是Linear，应该反转
        ffn_pattern = []
        ffn_ln_list = []

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.'
                   'model_structure_process.StructureProcess.is_ffn_matmul',
                   return_value=True):
            TorchDAGAdapter._get_ffn_pattern_and_ln_list((2, matmul_list), ln_list,
                                                         stop_node, ffn_pattern, ffn_ln_list)
            self.assertEqual(len(ffn_pattern), 1)
            self.assertEqual(len(ffn_ln_list), 1)

    def test_get_linear_linear_subgraph_structured_linear_order_false(self):
        """测试get_linear_linear_subgraph，structured_linear_order=False"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, '_find_module_by_name', return_value=Mock()):
                    with patch.object(self.adapter, '_find_latest_module_by_type_bfs', return_value=[linear1]):
                        result = self.adapter.get_linear_linear_subgraph(structured_linear_order=False)
                        self.assertIsInstance(result, defaultdict)

    def test_get_linear_linear_subgraph_more_than_four_linears(self):
        """测试超过4个线性层的情况"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        linear3 = MockDagNode("linear3", "Linear")
        linear4 = MockDagNode("linear4", "Linear")
        linear5 = MockDagNode("linear5", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, linear3, linear4, linear5, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, '_find_module_by_name', return_value=Mock()):
                    result = self.adapter.get_linear_linear_subgraph()
                    self.assertIsInstance(result, defaultdict)

    def test_get_linear_linear_subgraph_gate_down_case(self):
        """测试gate down情况"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, '_find_module_by_name') as mock_find:
                    mock_module = Mock()
                    mock_module.weight = Mock()
                    mock_module.weight.size.return_value = (10, 10)  # 模拟权重尺寸
                    mock_find.return_value = mock_module
                    result = self.adapter.get_linear_linear_subgraph()
                    self.assertIsInstance(result, defaultdict)


class TestTorchDAGAdapterEdgeCases(unittest.TestCase):
    """测试TorchDAGAdapter的边界情况"""

    def setUp(self):
        """测试前准备"""
        self.model = nn.Linear(10, 10)  # 使用简单的线性模型
        self.mock_dag_hook = Mock()
        self.mock_dag_hook.dag_node_list = []

        with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook.DagTorchHook',
                   return_value=self.mock_dag_hook):
            self.adapter = TorchDAGAdapter(self.model, dummy_input=torch.randn(1, 10))

    def test_empty_node_list(self):
        """测试空节点列表"""
        with patch.object(self.adapter, 'node_list', []):
            names = self.adapter.get_node_name_in_order()
            self.assertEqual(names, [])

            result = self.adapter.get_name_type_dict_in_order()
            self.assertEqual(result, {})

            topology = self.adapter.get_network_topology()
            self.assertEqual(topology, [])

    def test_single_node(self):
        """测试单个节点"""
        single_node = MockDagNode("single", "Linear")
        with patch.object(self.adapter, 'node_list', [single_node]):
            names = self.adapter.get_node_name_in_order()
            self.assertEqual(names, ["single"])

            result = self.adapter.get_name_type_dict_in_order()
            self.assertEqual(result, {"single": "Linear"})

    def test_no_norm_nodes(self):
        """测试没有norm节点的情况"""
        with patch.object(self.adapter, 'norm_nodes', []):
            result = self.adapter.get_norm_linear_subgraph()
            self.assertEqual(result, defaultdict(list))

    def test_cuda_device(self):
        """测试CUDA设备情况"""
        if torch.cuda.is_available():
            cuda_model = nn.Linear(10, 10).cuda()
            dummy_input = torch.randn(1, 10).cuda()

            mock_dag_hook = Mock()
            mock_dag_hook.dag_node_list = [MockDagNode("linear", "Linear")]

            with patch('msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.dag_torch_hook.DagTorchHook',
                       return_value=mock_dag_hook):
                adapter = TorchDAGAdapter(cuda_model, dummy_input=dummy_input)
                self.assertEqual(adapter._dummy_input.device.type, 'cuda')

    def test_find_latest_module_by_type_branch_aware_true(self):
        """测试branch_aware=True的情况"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")
        other_node = MockDagNode("other", "Linear")

        search_node.input_nodes = [target_node]

        with patch.object(self.adapter, '_find_latest_module_by_type') as mock_find:
            mock_find.return_value = [target_node]
            result = self.adapter._find_latest_module_by_type(search_node, "Linear", branch_aware=True)
            self.assertIsInstance(result, list)

    def test_find_latest_module_by_type_stop_node_in_output(self):
        """测试停止节点在输出中的情况"""
        search_node = MockDagNode("search", "Linear")
        stop_node = MockDagNode("stop", "Stop")
        output_node = MockDagNode("output", "Linear")

        output_node.output_nodes = [stop_node]
        search_node.input_nodes = [output_node]

        result = self.adapter._find_latest_module_by_type(search_node, "Linear", stop_node=stop_node)
        self.assertEqual(result, [])

    def test_find_latest_module_by_type_with_conv2d_or_getitem(self):
        """测试遇到Conv2d或GetItem的情况"""
        search_node = MockDagNode("search", "Linear")
        conv_node = MockDagNode("conv", ModuleType.CONV2D)

        search_node.input_nodes = [conv_node]

        result = self.adapter._find_latest_module_by_type(search_node, "Linear")
        self.assertEqual(result, [])

    def test_find_input_nodes_with_tensor_base(self):
        """测试查找输入节点，包含tensor base类型"""
        input_node = MockDagNode("input", "Input")
        tensor_node = MockDagNode("_TensorBase.__getitem__", "_TensorBase.__getitem__")  # 应该被过滤
        normal_node = MockDagNode("normal", "Linear")

        input_node.inputs = [Mock()]
        input_node.inputs[0].dag_node_from = None  # 没有来源，是输入

        tensor_node.inputs = [Mock()]
        tensor_node.inputs[0].dag_node_from = Mock()

        normal_node.inputs = [Mock()]
        normal_node.inputs[0].dag_node_from = None

        with patch.object(self.adapter, 'node_list', [input_node, tensor_node, normal_node]):
            input_nodes = self.adapter._find_input_nodes()
            self.assertIn(input_node, input_nodes)
            self.assertNotIn(tensor_node, input_nodes)  # 被dark_name_list过滤
            self.assertIn(normal_node, input_nodes)

    def test_find_output_nodes_with_tensor_base(self):
        """测试查找输出节点，包含tensor base类型"""
        output_node = MockDagNode("output", "Output")
        tensor_node = MockDagNode("_TensorBase.size", "_TensorBase.size")  # 应该被过滤
        normal_node = MockDagNode("normal", "Linear")

        output_node.outputs = [Mock()]
        output_node.outputs[0].dag_nodes_to = []  # 没有去向，是输出

        tensor_node.outputs = [Mock()]
        tensor_node.outputs[0].dag_nodes_to = [Mock()]

        normal_node.outputs = [Mock()]
        normal_node.outputs[0].dag_nodes_to = []  # 没有去向，是输出

        with patch.object(self.adapter, 'node_list', [output_node, tensor_node, normal_node]):
            output_nodes = self.adapter._find_output_nodes()
            self.assertIn(output_node, output_nodes)
            self.assertNotIn(tensor_node, output_nodes)  # 被dark_name_list过滤
            self.assertIn(normal_node, output_nodes)

    def test_get_linear_linear_subgraph_more_than_two_linears(self):
        """测试get_linear_linear_subgraph，2 < num_of_linears <= 4"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        linear3 = MockDagNode("linear3", "Linear")
        linear4 = MockDagNode("linear4", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, linear3, linear4, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, '_find_module_by_name', return_value=Mock()):
                    result = self.adapter.get_linear_linear_subgraph()
                    self.assertIsInstance(result, defaultdict)

    def test_get_linear_linear_subgraph_two_linears_gate_down(self):
        """测试get_linear_linear_subgraph，2个线性层，gate down情况"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, '_find_module_by_name') as mock_find:
                    # 模拟gate down情况：origin_layer.weight.size(1) == target_layer.weight.size(0)
                    mock_module = Mock()
                    mock_module.weight = Mock()
                    mock_module.weight.size.return_value = (10, 10)  # 模拟尺寸匹配
                    mock_find.return_value = mock_module
                    result = self.adapter.get_linear_linear_subgraph()
                    self.assertIsInstance(result, defaultdict)

    def test_get_kv_linears_even_interval(self):
        """测试get_kv_linears，偶数区间情况"""
        norm_node = MockDagNode("norm", "LayerNorm")
        linear1 = MockDagNode("linear1", "Linear")
        linear2 = MockDagNode("linear2", "Linear")
        linear3 = MockDagNode("linear3", "Linear")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, linear1, linear2, linear3, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                kv_linears, num_kv = self.adapter.get_kv_linears()
                self.assertIsInstance(kv_linears, list)
                self.assertIsInstance(num_kv, int)

    def test_get_linear_linear_subgraph_no_target_node(self):
        """测试get_linear_linear_subgraph，没有目标节点的情况"""
        norm_node = MockDagNode("norm", "LayerNorm")
        next_norm = MockDagNode("norm2", "LayerNorm")

        test_nodes = [norm_node, next_norm]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch.object(self.adapter, 'norm_nodes', ['layernorm']):
                with patch.object(self.adapter, '_find_module_by_name', return_value=Mock()):
                    result = self.adapter.get_linear_linear_subgraph()
                    self.assertIsInstance(result, defaultdict)

    def test_get_mhsa_pattern_with_multiple_conditions(self):
        """测试get_mhsa_pattern函数，多个条件分支"""
        # 测试条件: len(add_node) != 1
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node1 = MockDagNode("add1", "__add__")
        mock_add_node2 = MockDagNode("add2", "__add__")
        mock_node.output_nodes = [mock_add_node1, mock_add_node2]  # 2个add节点

        test_nodes = [mock_node, mock_add_node1, mock_add_node2]

        with (patch.object(self.adapter, 'node_list', test_nodes)):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.mhsa_matmul_process') \
                    as mock_process:
                qkv_list, proj_list = self.adapter.get_mhsa_pattern()
                mock_process.assert_not_called()  # 应该不被调用
                self.assertEqual(len(qkv_list), 0)
                self.assertEqual(len(proj_list), 0)

    def test_get_mhsa_pattern_with_len_inputs_not_2(self):
        """测试get_mhsa_pattern函数，add_node[0].inputs长度不为2"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock()]  # 只有1个输入
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]
        mock_add_node.input_nodes = [mock_node]  # 只有1个输入节点

        test_nodes = [mock_node, mock_add_node, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.mhsa_matmul_process') \
                    as mock_process:
                qkv_list, proj_list = self.adapter.get_mhsa_pattern()
                mock_process.assert_not_called()  # 应该不被调用
                self.assertEqual(len(qkv_list), 0)
                self.assertEqual(len(proj_list), 0)

    def test_get_mhsa_pattern_with_no_other_branch(self):
        """测试get_mhsa_pattern函数，没有other_branch"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]  # 2个输入
        mock_add_node.input_nodes = [mock_node, MockDagNode("other_input", "Linear")]

        mock_node.output_nodes = [mock_add_node]  # 没有other_branch

        test_nodes = [mock_node, mock_add_node, MockDagNode("other_input", "Linear")]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.model_structure_process.'
                    'StructureProcess.mhsa_matmul_process') \
                    as mock_process:
                qkv_list, proj_list = self.adapter.get_mhsa_pattern()
                mock_process.assert_not_called()  # 应该不被调用
                self.assertEqual(len(qkv_list), 0)
                self.assertEqual(len(proj_list), 0)

    def test_get_llama_ffn_ln_pattern_with_multiple_conditions(self):
        """测试get_llama_ffn_ln_pattern函数，多个条件分支"""
        # 测试条件: len(add_node) != 1
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node1 = MockDagNode("add1", "__add__")
        mock_add_node2 = MockDagNode("add2", "__add__")
        mock_other_branch = MockDagNode("other", "Linear")
        mock_node.output_nodes = [mock_add_node1, mock_add_node2, mock_other_branch]  # 2个add节点

        test_nodes = [mock_node, mock_add_node1, mock_add_node2, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                    'TorchDAGAdapter._get_ffn_pattern_and_ln_list') \
                    as mock_get:
                ffn_pattern, ffn_ln_list = self.adapter.get_llama_ffn_ln_pattern()
                mock_get.assert_not_called()  # 应该不被调用
                self.assertEqual(len(ffn_pattern), 0)
                self.assertEqual(len(ffn_ln_list), 0)

    def test_get_llama_ffn_ln_pattern_with_len_other_branch_not_1(self):
        """测试get_llama_ffn_ln_pattern函数，other_branch长度不为1"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock(), Mock()]  # 2个输入
        mock_other_branch1 = MockDagNode("other1", "Linear")
        mock_other_branch2 = MockDagNode("other2", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch1, mock_other_branch2]  # 2个other_branch

        test_nodes = [mock_node, mock_add_node, mock_other_branch1, mock_other_branch2]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                    'TorchDAGAdapter._get_ffn_pattern_and_ln_list') \
                    as mock_get:
                ffn_pattern, ffn_ln_list = self.adapter.get_llama_ffn_ln_pattern()
                mock_get.assert_not_called()  # 应该不被调用
                self.assertEqual(len(ffn_pattern), 0)
                self.assertEqual(len(ffn_ln_list), 0)

    def test_get_llama_ffn_ln_pattern_with_len_add_inputs_not_2(self):
        """测试get_llama_ffn_ln_pattern函数，add_node输入长度不为2"""
        mock_node = MockDagNode("test_node", "Linear")
        mock_add_node = MockDagNode("add", "__add__")
        mock_add_node.inputs = [Mock()]  # 只有1个输入
        mock_other_branch = MockDagNode("other", "Linear")

        mock_node.output_nodes = [mock_add_node, mock_other_branch]

        test_nodes = [mock_node, mock_add_node, mock_other_branch]

        with patch.object(self.adapter, 'node_list', test_nodes):
            with patch(
                    'msmodelslim.pytorch.llm_ptq.anti_outlier.dag_utils.torch_dag_adapter.'
                    'TorchDAGAdapter._get_ffn_pattern_and_ln_list') \
                    as mock_get:
                ffn_pattern, ffn_ln_list = self.adapter.get_llama_ffn_ln_pattern()
                mock_get.assert_not_called()  # 应该不被调用
                self.assertEqual(len(ffn_pattern), 0)
                self.assertEqual(len(ffn_ln_list), 0)

    def test_find_latest_module_by_type_with_stop_node_in_output_nodes(self):
        """测试_find_latest_module_by_type函数，stop_node在node.output_nodes中"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")
        stop_node = MockDagNode("stop", "Stop")

        search_node.input_nodes = [target_node]
        target_node.output_nodes = [stop_node]  # stop_node在target_node的输出中

        # 验证stop_node在target_node的输出节点列表中
        self.assertIn(stop_node, list(target_node.output_nodes))

        result = self.adapter._find_latest_module_by_type(search_node, "Linear", stop_node=stop_node)
        self.assertEqual(result, [])  # 应该返回空列表，因为stop_node在输出中

    def test_find_latest_module_by_type_with_stop_node_equal_current(self):
        """测试_find_latest_module_by_type函数，当前节点等于stop_node"""
        search_node = MockDagNode("search", "Linear")
        stop_node = MockDagNode("stop", "Stop")

        search_node.input_nodes = [stop_node]  # 直接连接到stop_node

        result = self.adapter._find_latest_module_by_type(search_node, "Linear", stop_node=stop_node)
        self.assertEqual(result, [])  # 应该返回空列表，因为当前节点就是stop_node

    def test_find_latest_module_by_type_with_found_target_node(self):
        """测试_find_latest_module_by_type函数，找到目标节点"""
        search_node = MockDagNode("search", "Linear")
        target_node = MockDagNode("target", "Linear")
        target_node.op_type = "Linear"

        search_node.input_nodes = [target_node]

        result = self.adapter._find_latest_module_by_type(search_node, "Linear")
        self.assertEqual(result, [target_node])  # 应该返回包含目标节点的列表

    def test_find_latest_module_by_type_with_conv2d_or_getitem_stop(self):
        """测试_find_latest_module_by_type函数，遇到Conv2d或GetItem时停止"""
        search_node = MockDagNode("search", "Linear")
        conv_node = MockDagNode("conv", ModuleType.CONV2D)

        search_node.input_nodes = [conv_node]

        result = self.adapter._find_latest_module_by_type(search_node, "Linear")
        self.assertEqual(result, [])  # 应该返回空列表，因为遇到Conv2d停止


if __name__ == '__main__':
    unittest.main()